package ffte.algorithm

import breeze.math.Complex
import ffte.algorithm.FFTMisc

/**
 * Helpers shared by [[circularOrder]] instances: twiddle factors plus
 * reference circular convolution / DFT-based combination routines.
 */
object circularOrder {
    /**
     * Twiddle factor `exp(+2*pi*i*k/n)` (note the positive sign of the exponent).
     *
     * @param k exponent index
     * @param n transform length
     * @return `cos(2*pi*k/n) + i*sin(2*pi*k/n)`
     */
    def factor(k: Int, n: Int): Complex = {
        val theta = 2 * Math.PI * k / n
        Complex(Math.cos(theta), Math.sin(theta))
    }

    /**
     * Direct circular convolution: `out(n) = sum_k x(k) * y((n - k) mod N)` with `N = x.length`.
     *
     * Both inputs are copied to indexed sequences first. `List.apply` is O(index), so
     * indexing the lists directly inside the double loop would make this O(N^3)
     * instead of O(N^2).
     *
     * @param x first operand; its length defines N
     * @param y second operand; must have at least `x.length` elements
     * @return the N-point circular convolution of `x` and `y`
     */
    def circularFilter(x: List[Complex], y: List[Complex]): List[Complex] = {
        val xs = x.toIndexedSeq
        val ys = y.toIndexedSeq
        val N = xs.length
        List.tabulate(N) { n =>
            // (N + n - k) keeps the index non-negative before taking the modulus.
            (0 until N).map(k => xs(k) * ys((N + n - k) % N)).sum
        }
    }

    /**
     * Combines `x` and `y` through DFTs: computes `dft(fx(k) * conj(fy(k)) / n)`,
     * where `fx = dft(x)`, `fy = dft(y)` and `n = fx.length`.
     *
     * NOTE(review): if `FFTMisc.dft` is the unnormalized DFT (either sign convention),
     * element m of the result is `sum_i x(i) * conj(y((i + m) mod n))`, i.e. a cyclic
     * cross-correlation. Confirm this against `FFTMisc`.
     *
     * @param x first sequence
     * @param y second sequence, expected to have the same length as `x`
     * @return the transformed elementwise product
     */
    def circularXbyDFT(x: List[Complex], y: List[Complex]) = {
        val fx = FFTMisc.dft(x)
        val fy = FFTMisc.dft(y)
        val n = fx.length
        FFTMisc.dft(fx.zip(fy).map { case (a, b) => a * b.conjugate / n })
    }
}

/**
 * Primitive-root ("circular") ordering of the indices `0 until p`.
 *
 * Position 0 always holds index 0. Positions `1 until p` hold the successive powers
 * `g^0, g^1, ..., g^(p-2)` (mod p) of `g = root(0)`, the first primitive root found.
 * Because a primitive root generates every non-zero residue modulo a prime `p`, this
 * is a permutation of `0 until p`.
 *
 * In this order, the non-zero-index part of a length-p DFT becomes a cyclic correlation
 * of length `p - 1` (Rader's algorithm). [[dft]], [[wcoeff]] and [[wfft]] build on this.
 *
 * The primitive roots and both permutation tables are computed once per instance and
 * cached. The public accessors return copies, so callers cannot corrupt the cache.
 *
 * @param p transform size; expected to be prime (otherwise the tables are generally
 *          not a permutation)
 */
case class circularOrder(p: Int) {
    /** All primitive roots modulo `p`, in ascending order; computed on first use. */
    private lazy val rootCache: Array[Int] = p match {
        // The multiplicative group mod 2 is {1}, so 1 is its only generator.
        case 2 => Array(1)
        case 3 => Array(2)
        case _ =>
            // g is a primitive root iff g^((p-1)/q) != 1 (mod p) for every prime q dividing p-1.
            // NOTE(review): assumes the keys of Prime.repesent(p - 1) are those prime factors — confirm.
            val primeFactors = Prime.repesent(p - 1).keys
            val modulus = BigInt(p)
            (2 until p - 1).filter { g =>
                !primeFactors.exists(q => BigInt(g).modPow(BigInt(p - 1) / q, modulus) == BigInt(1))
            }.toArray
    }

    /** Cached forward table: entry 0 is 0, entry i >= 1 is g^(i-1) mod p. */
    private lazy val tabCache: Array[Int] = {
        val g = rootCache(0).toLong
        val out = new Array[Int](p) // out(0) stays 0: index 0 is not permuted
        var power = 1L              // Long so that power * g cannot overflow for large p
        for (i <- 1 until p) {
            out(i) = power.toInt
            power = power * g % p
        }
        out
    }

    /** Cached inverse of [[tabCache]]: rtabCache(tabCache(i)) == i. */
    private lazy val rtabCache: Array[Int] = {
        val inv = new Array[Int](p)
        for (i <- 0 until p) inv(tabCache(i)) = i
        inv
    }

    /** Twiddle factors `factor(k, n)` for `k` in `0 until n`, in natural order. */
    private def twiddles(n: Int): List[Complex] = List.tabulate(n)(k => circularOrder.factor(k, n))

    /**
     * Forward permutation table.
     *
     * `tab(i)` is the natural index placed at circular position `i`:
     * `tab(0) == 0` and `tab(i) == g^(i-1) mod p` for `i >= 1`.
     *
     * @return a fresh copy of the table (length `p`)
     */
    def tab: Array[Int] = tabCache.clone()

    /**
     * Reverse permutation table, the inverse of [[tab]].
     *
     * `rtab(k)` is the circular position of natural index `k`, so `rtab(tab(i)) == i`.
     *
     * @return a fresh copy of the table (length `p`)
     */
    def rtab: Array[Int] = rtabCache.clone()

    /**
     * Primitive roots modulo `p`.
     *
     * A primitive root is a number whose powers generate all non-zero residues mod `p`.
     * Only `root(0)` is used to build the ordering.
     *
     * @return a fresh array of all primitive roots found, in ascending order
     */
    def root: Array[Int] = rootCache.clone()

    /**
     * Table lookup `tab(n)`: the natural index at circular position `n`.
     *
     * @param n circular position (0 to p-1)
     * @return the natural index stored at that position
     */
    def order(n: Int): Int = tabCache(n)

    /**
     * Reorders a list from natural order into circular order: `to(x)(i) == x(tab(i))`.
     *
     * @param x input list in natural order (expected length `p`)
     * @tparam T list element type
     * @return the list in circular order
     */
    def to[T](x: List[T]): List[T] = {
        val xs = x.toIndexedSeq // avoid O(index) List.apply inside the loop
        List.tabulate(xs.length)(n => xs(tabCache(n)))
    }

    /**
     * Reorders a list from circular order back to natural order: `from(x)(k) == x(rtab(k))`.
     * This undoes [[to]].
     *
     * @param x input list in circular order (expected length `p`)
     * @tparam T list element type
     * @return the list in natural order
     */
    def from[T](x: List[T]): List[T] = {
        val xs = x.toIndexedSeq // avoid O(index) List.apply inside the loop
        List.tabulate(xs.length)(n => xs(rtabCache(n)))
    }

    /**
     * DFT of `x` via the circular (primitive-root) reordering (Rader's algorithm).
     *
     * Circular-order output 0 is `x.sum`. The remaining outputs are `x(0)` plus the
     * `circularXbyDFT` combination of the reordered non-zero-index inputs with the
     * reordered twiddles `factor(k, N)`. The result is mapped back to natural order
     * with [[from]].
     *
     * NOTE(review): `x.length` is expected to equal `p`, since the tables have `p` entries.
     * The sign and normalization follow `circularOrder.circularXbyDFT` / `FFTMisc.dft`.
     *
     * @param x input samples in natural order
     * @return transform outputs in natural order
     */
    def dft(x: List[Complex]) = {
        val N = x.length
        val rf = to(twiddles(N))
        val rx = to(x)
        val fout = circularOrder.circularXbyDFT(rx.tail, rf.tail).map(x(0) + _)
        from(x.sum :: fout)
    }

    /**
     * Precomputed coefficients for the length-(p-1) circular part.
     *
     * Computed as the conjugated `FFTMisc.dft` of the circular-ordered non-zero
     * twiddles `factor(k, p)`, divided by `p - 1`.
     *
     * @return the `p - 1` coefficients
     */
    def wcoeff = {
        val rf = to(twiddles(p))
        val rft = FFTMisc.dft(rf.tail).map(_.conjugate)
        rft.map(_ / (p - 1))
    }

    /**
     * Fixed-point quantization of [[wcoeff]] to `dW`-bit integers.
     *
     * Each coefficient is scaled by `2^s` and rounded. Real and imaginary parts are then
     * saturated to `±(2^(dW-1) - 1)`. `s` is derived from the largest coefficient
     * magnitude via `FFTMisc.log2Up`.
     *
     * NOTE(review): the second tuple element, `s - dW + 1`, appears to be the shift the
     * caller needs to undo the scaling — confirm against its users.
     *
     * @param dW target word width in bits
     * @return (quantized (re, im) pairs, `s - dW + 1`)
     */
    def wfft(dW: Int) = {
        val ff = wcoeff
        val maxff = ff.map(_.abs).max
        val s = 2 * dW - FFTMisc.log2Up(Math.round(maxff * (1 << dW)).toInt) - 1
        // Symmetric saturation range of a signed dW-bit integer.
        val limit = (1 << (dW - 1)) - 1
        def sat(v: Int): Int = if (v >= limit) limit else if (v <= -limit) -limit else v
        val quantized = ff.map { c =>
            val a = c * (1 << s)
            (sat(Math.round(a.re).toInt), sat(Math.round(a.im).toInt))
        }
        (quantized, s - dW + 1)
    }

    /** Formats as `co(p):<tab entries>|<rtab entries>`, each comma-separated. */
    override def toString: String =
        s"co($p):" + tabCache.mkString(",") + "|" + rtabCache.mkString(",")
}